import os
import time
import numpy as np
import pandas as pd
import multiprocessing as mp

# Globals (initialized in init_params)
# They live at module level so the simulation functions, which only receive a
# task index, can read the configuration; every pool worker fills them in by
# running init_params as its initializer.
n = None              # length of the random vector and side of the square weight matrices
s = None              # leading dimension of the (s, n, n) weight stack, i.e. rows of X
H = None              # stored from params['H']; not read anywhere else in the visible code
num_runs = None       # number of tasks handed to pool.map
num_processes = None  # size of the multiprocessing pool
seed = None           # value passed to np.random.seed

def scaled_weights(shape):
    """Return an array of the given shape with i.i.d. N(0, 1/n) entries.

    Standard-normal draws from NumPy's global RNG are divided by sqrt(n),
    where ``n`` is the module-level global set by ``init_params``.

    Args:
        shape: iterable of ints giving the dimensions of the result.
    """
    standard_normal_draws = np.random.randn(*shape)
    std_scale = np.sqrt(n)
    return standard_normal_draws / std_scale

def init_params(params):
    """Copy the simulation parameters into module globals and seed NumPy's RNG.

    This function is called once in the parent process and, because it is
    also used as the ``mp.Pool`` initializer, once in every worker process.

    Args:
        params: mapping with the keys 'n', 's', 'H', 'num_runs',
            'num_processes' and 'seed'.

    Raises:
        KeyError: if one of the expected keys is missing from ``params``.
    """
    global n, s, H, num_runs, num_processes, seed
    n = params['n']
    s = params['s']
    H = params['H']
    num_runs = params['num_runs']
    num_processes = params['num_processes']
    seed = params['seed']

    # Every pool worker runs this initializer with the same params. Seeding
    # them all with the bare seed would make every worker emit the *same*
    # random stream, so the "independent" runs would just be copies of each
    # other. Mixing the worker's identity into the seed gives each worker its
    # own stream.
    # NOTE: _identity is a CPython implementation detail of multiprocessing:
    # an empty tuple in the main process, a tuple of ints in child processes.
    worker_id = getattr(mp.current_process(), '_identity', ())
    if seed is None or not worker_id:
        # Main process (or no explicit seed): keep the original seeding.
        np.random.seed(seed)
    else:
        np.random.seed(np.append(seed, worker_id))

def single_run_sqrt(_):
    """Draw one random score normalised by sqrt(n).

    A standard-normal vector of length n is multiplied by a stack of s
    scaled (n, n) matrices to form X with shape (s, n). X is then projected
    through two further scaled (n, n) matrices to give Q and K, and the
    [0, 0] entry of Q K^T / sqrt(n) is returned.

    Args:
        _: ignored; present so the function can be mapped over a range of
            task indices.

    Returns:
        The scalar at position [0, 0] of the scaled product.
    """
    # The order of the random draws is kept fixed so that a given RNG state
    # always yields the same result.
    hidden = np.random.randn(n)
    inputs = np.matmul(scaled_weights((s, n, n)), hidden)
    w_query = scaled_weights((n, n))
    w_key = scaled_weights((n, n))

    queries = np.matmul(inputs, w_query)
    keys = np.matmul(inputs, w_key)
    scores = np.dot(queries, keys.T) / np.sqrt(n)
    return scores[0, 0]

def single_run_n(_):
    """Draw one random score normalised by n.

    The construction matches the sqrt(n) variant (a random length-n vector
    pushed through a stack of s scaled (n, n) matrices, then two scaled
    (n, n) projections), but the product Q K^T is divided by n.

    Args:
        _: ignored; present so the function can be mapped over a range of
            task indices.

    Returns:
        The scalar at position [0, 0] of Q K^T / n.
    """
    h_vec = np.random.randn(n)
    x_rows = scaled_weights((s, n, n)) @ h_vec  # shape (s, n)

    # Tuple elements are evaluated left to right, so the first projection is
    # drawn before the second.
    proj_q, proj_k = scaled_weights((n, n)), scaled_weights((n, n))

    q_rows = x_rows @ proj_q
    k_rows = x_rows @ proj_k
    gram = q_rows.dot(k_rows.T)
    return (gram / n)[0, 0]

def simulate_and_save(params, sim_fn, output_csv):
    """Run ``sim_fn`` ``num_runs`` times in a process pool and write a CSV.

    The parameters are installed in the parent process first, then again in
    every worker through the pool initializer. The scalar returned by each
    run is stored in a single 'value' column, and the elapsed wall-clock
    time is printed once the file has been written.

    Args:
        params: parameter mapping forwarded to ``init_params``.
        sim_fn: function taking one (ignored) task index and returning a
            scalar; it is sent to the pool workers.
        output_csv: destination path of the CSV file.
    """
    init_params(params)
    started_at = time.time()

    pool_options = {
        'processes': num_processes,
        'initializer': init_params,
        'initargs': (params,),
    }
    with mp.Pool(**pool_options) as workers:
        results = workers.map(sim_fn, range(num_runs))

    pd.DataFrame({'value': results}).to_csv(output_csv, index=False)

    elapsed = time.time() - started_at
    print(f"{output_csv} written in {elapsed:.2f}s")

if __name__ == '__main__':
    # Configuration shared by both experiments.
    params = {
        'n': 256,
        's': 4,
        'H': 2,
        'num_runs': 300000,
        'num_processes': 18,
        'seed': 0,
    }

    # (simulation function, output file) pairs, executed in this order.
    experiments = (
        (single_run_sqrt, 'score_sqrt.csv'),  # 1/√n scaling
        (single_run_n, 'score_n.csv'),        # 1/n scaling
    )
    for sim_fn, csv_path in experiments:
        simulate_and_save(params, sim_fn, csv_path)
